{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# \ud83d\udcc3 Solution for Exercise M5.01\n",
    "\n",
    "In the previous notebook, we showed how a tree with 1 level depth works. The\n",
    "aim of this exercise is to repeat part of the previous experiment for a tree\n",
    "with 2 levels depth to show how such parameter affects the feature space\n",
    "partitioning.\n",
    "\n",
    "We first load the penguins dataset and split it into a training and a testing\n",
    "sets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "penguins = pd.read_csv(\"../datasets/penguins_classification.csv\")\n",
    "culmen_columns = [\"Culmen Length (mm)\", \"Culmen Depth (mm)\"]\n",
    "target_column = \"Species\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"admonition note alert alert-info\">\n",
    "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n",
    "<p class=\"last\">If you want a deeper overview regarding this dataset, you can refer to the\n",
    "Appendix - Datasets description section at the end of this MOOC.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "data, target = penguins[culmen_columns], penguins[target_column]\n",
    "data_train, data_test, target_train, target_test = train_test_split(\n",
    "    data, target, random_state=0\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a decision tree classifier with a maximum depth of 2 levels and fit the\n",
    "training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solution\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "tree = DecisionTreeClassifier(max_depth=2)\n",
    "tree.fit(data_train, target_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now plot the data and the decision boundary of the trained classifier to see\n",
    "the effect of increasing the depth of the tree.\n",
    "\n",
    "Hint: Use the class `DecisionBoundaryDisplay` from the module\n",
    "`sklearn.inspection` as shown in previous course notebooks.\n",
    "\n",
    "<div class=\"admonition warning alert alert-danger\">\n",
    "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n",
    "<p class=\"last\">At this time, it is not possible to use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict_proba\"</span></tt> for\n",
    "multiclass problems on a single plot. This is a planned feature for a future\n",
    "version of scikit-learn. In the mean time, you can use\n",
    "<tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt> instead.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solution\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import seaborn as sns\n",
    "\n",
    "from sklearn.inspection import DecisionBoundaryDisplay\n",
    "\n",
    "\n",
    "tab10_norm = mpl.colors.Normalize(vmin=-0.5, vmax=8.5)\n",
    "\n",
    "palette = [\"tab:blue\", \"tab:green\", \"tab:orange\"]\n",
    "DecisionBoundaryDisplay.from_estimator(\n",
    "    tree,\n",
    "    data_train,\n",
    "    response_method=\"predict\",\n",
    "    cmap=\"tab10\",\n",
    "    norm=tab10_norm,\n",
    "    alpha=0.5,\n",
    ")\n",
    "ax = sns.scatterplot(\n",
    "    data=penguins,\n",
    "    x=culmen_columns[0],\n",
    "    y=culmen_columns[1],\n",
    "    hue=target_column,\n",
    "    palette=palette,\n",
    ")\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc=\"upper left\")\n",
    "_ = plt.title(\"Decision boundary using a decision tree\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Did we make use of the feature \"Culmen Length\"? Plot the tree using the\n",
    "function `sklearn.tree.plot_tree` to find out!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solution\n",
    "from sklearn.tree import plot_tree\n",
    "\n",
    "_, ax = plt.subplots(figsize=(16, 12))\n",
    "_ = plot_tree(\n",
    "    tree,\n",
    "    feature_names=culmen_columns,\n",
    "    class_names=tree.classes_.tolist(),\n",
    "    impurity=False,\n",
    "    ax=ax,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "solution"
    ]
   },
   "source": [
    "The resulting tree has 7 nodes: 3 of them are \"split nodes\" and 4 are \"leaf\n",
    "nodes\" (or simply \"leaves\"), organized in 2 levels. We see that the second\n",
    "tree level used the \"Culmen Length\" to make two new decisions. Qualitatively,\n",
    "we saw that such a simple tree was enough to classify the penguins' species."
   ]
  },
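  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The node counts read off the plot can also be checked programmatically. The\n",
    "snippet below is a minimal, self-contained sketch rather than part of the\n",
    "solution: it fits a depth-2 tree on synthetic data from `make_blobs` (a\n",
    "stand-in, not the penguins dataset), and every `toy_` name is hypothetical.\n",
    "The same attributes and methods should be available on the `tree` fitted\n",
    "above.\n",
    "\n",
    "```python\n",
    "from sklearn.datasets import make_blobs\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Synthetic stand-in: 2 features and 3 classes (assumption, not the penguins)\n",
    "toy_data, toy_target = make_blobs(n_samples=150, centers=3, random_state=0)\n",
    "\n",
    "toy_tree = DecisionTreeClassifier(max_depth=2, random_state=0)\n",
    "toy_tree.fit(toy_data, toy_target)\n",
    "\n",
    "n_nodes = toy_tree.tree_.node_count\n",
    "n_leaves = toy_tree.get_n_leaves()\n",
    "n_splits = n_nodes - n_leaves\n",
    "\n",
    "# Leaves store a negative placeholder in `tree_.feature`, so keeping the\n",
    "# non-negative entries lists the feature index tested at each split node.\n",
    "split_features = toy_tree.tree_.feature[toy_tree.tree_.feature >= 0]\n",
    "\n",
    "print(f\"{n_nodes} nodes: {n_splits} split nodes and {n_leaves} leaves\")\n",
    "print(f\"Depth: {toy_tree.get_depth()}\")\n",
    "print(f\"Feature index used at each split: {split_features.tolist()}\")\n",
    "```\n",
    "\n",
    "Listing the feature index of each split node is another way to answer whether\n",
    "a given feature was used, without reading it off the figure."
   ]
  },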
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the accuracy of the decision tree on the testing data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solution\n",
    "test_score = tree.fit(data_train, target_train).score(data_test, target_test)\n",
    "print(f\"Accuracy of the DecisionTreeClassifier: {test_score:.2f}\")"
   ]
  },
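  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a classifier, `score` reports the mean accuracy, that is, the fraction of\n",
    "samples whose predicted class matches the true class. The snippet below is a\n",
    "minimal sketch of that equivalence on synthetic data from `make_blobs` (an\n",
    "assumption standing in for the penguins dataset); the `toy_` names are\n",
    "hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_blobs\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Synthetic stand-in: 2 features and 3 classes (assumption, not the penguins)\n",
    "toy_data, toy_target = make_blobs(n_samples=150, centers=3, random_state=0)\n",
    "toy_X_train, toy_X_test, toy_y_train, toy_y_test = train_test_split(\n",
    "    toy_data, toy_target, random_state=0\n",
    ")\n",
    "\n",
    "toy_tree = DecisionTreeClassifier(max_depth=2, random_state=0)\n",
    "toy_tree.fit(toy_X_train, toy_y_train)\n",
    "\n",
    "# Accuracy is the fraction of test samples predicted with the correct class\n",
    "manual_accuracy = np.mean(toy_tree.predict(toy_X_test) == toy_y_test)\n",
    "score_accuracy = toy_tree.score(toy_X_test, toy_y_test)\n",
    "print(f\"manual: {manual_accuracy:.2f}, score: {score_accuracy:.2f}\")\n",
    "```"
   ]
  },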
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "solution"
    ]
   },
   "source": [
    "At this stage, we have the intuition that a decision tree is built by\n",
    "successively partitioning the feature space, considering one feature at a\n",
    "time.\n",
    "\n",
    "We predict an Adelie penguin if the feature value is below the threshold,\n",
    "which is not surprising since this partition was almost pure. If the feature\n",
    "value is above the threshold, we predict the Gentoo penguin, the class that is\n",
    "most probable.\n",
    "\n",
    "## (Estimated) predicted probabilities in multi-class problems\n",
    "\n",
    "For those interested, one can further try to visualize the output of\n",
    "`predict_proba` for a multiclass problem using `DecisionBoundaryDisplay`,\n",
    "except that for a K-class problem you have K probability outputs for each\n",
    "data point. Visualizing all these on a single plot can quickly become tricky\n",
    "to interpret. It is then common to instead produce K separate plots, one for\n",
    "each class, in a one-vs-rest (or one-vs-all) fashion. This can be achieved by\n",
    "calling `DecisionBoundaryDisplay` several times, once for each class, and\n",
    "passing the `class_of_interest` parameter to the function.\n",
    "\n",
    "For example, in the plot below, the first plot on the left shows the\n",
    "certainty of classifying a data point as belonging to the \"Adelie\" class. The\n",
    "darker the color, the more certain the model is that a given point in the\n",
    "feature space belongs to a given class. The same logic applies to the other\n",
    "plots in the figure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "solution"
    ]
   },
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "\n",
    "_, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))\n",
    "plt.suptitle(\"Predicted probabilities for decision tree model\", y=1.05)\n",
    "plt.subplots_adjust(bottom=0.45)\n",
    "\n",
    "for idx, (class_of_interest, ax) in enumerate(zip(tree.classes_, axs)):\n",
    "    ax.set_title(f\"Class {class_of_interest}\")\n",
    "    DecisionBoundaryDisplay.from_estimator(\n",
    "        tree,\n",
    "        data_test,\n",
    "        response_method=\"predict_proba\",\n",
    "        class_of_interest=class_of_interest,\n",
    "        ax=ax,\n",
    "        vmin=0,\n",
    "        vmax=1,\n",
    "        cmap=\"Blues\",\n",
    "    )\n",
    "    ax.scatter(\n",
    "        data_test[\"Culmen Length (mm)\"].loc[target_test == class_of_interest],\n",
    "        data_test[\"Culmen Depth (mm)\"].loc[target_test == class_of_interest],\n",
    "        marker=\"o\",\n",
    "        c=\"w\",\n",
    "        edgecolor=\"k\",\n",
    "    )\n",
    "    ax.set_xlabel(\"Culmen Length (mm)\")\n",
    "    if idx == 0:\n",
    "        ax.set_ylabel(\"Culmen Depth (mm)\")\n",
    "\n",
    "ax = plt.axes([0.15, 0.14, 0.7, 0.05])\n",
    "plt.colorbar(cm.ScalarMappable(cmap=\"Blues\"), cax=ax, orientation=\"horizontal\")\n",
    "_ = ax.set_title(\"Predicted class membership probability\")"
   ]
  },
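  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plots above show regions of constant color because a decision tree\n",
    "predicts the same probabilities for every point that falls in the same leaf:\n",
    "the class fractions of the training samples in that leaf. The snippet below\n",
    "is a minimal sketch of this behavior on synthetic data from `make_blobs` (an\n",
    "assumption standing in for the penguins dataset); the `toy_` names are\n",
    "hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_blobs\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Synthetic stand-in with K=3 overlapping classes (assumption)\n",
    "toy_data, toy_target = make_blobs(\n",
    "    n_samples=150, centers=3, cluster_std=3.0, random_state=0\n",
    ")\n",
    "toy_tree = DecisionTreeClassifier(max_depth=2, random_state=0)\n",
    "toy_tree.fit(toy_data, toy_target)\n",
    "\n",
    "# One row per sample and one column per class; each row sums to 1\n",
    "toy_proba = toy_tree.predict_proba(toy_data)\n",
    "print(toy_proba.shape)\n",
    "print(toy_proba.sum(axis=1)[:5])\n",
    "\n",
    "# Samples in the same leaf share the same row of probabilities, so there are\n",
    "# at most as many distinct rows as there are leaves.\n",
    "n_distinct_rows = len(np.unique(toy_proba, axis=0))\n",
    "print(f\"{n_distinct_rows} distinct rows, {toy_tree.get_n_leaves()} leaves\")\n",
    "```\n",
    "\n",
    "With a depth of 2 there are at most 4 leaves, hence at most 4 distinct\n",
    "probability levels per class in each of the plots."
   ]
  },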
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "solution"
    ]
   },
   "source": [
    "\n",
    "<div class=\"admonition note alert alert-info\">\n",
    "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n",
    "<p class=\"last\">You may notice that we do not use a diverging colormap (2 color gradients with\n",
    "white in the middle). Indeed, in a multiclass setting, 0.5 is not a\n",
    "meaningful value, hence using white as the center of the colormap is not\n",
    "appropriate. Instead, we use a sequential colormap, where the color intensity\n",
    "indicates the certainty of the classification. The darker the color, the more\n",
    "certain the model is that a given point in the feature space belongs to a\n",
    "given class.</p>\n",
    "</div>"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}